%% interventionMatrix.m
%
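% Solve for the intervention matrix from the adjustment policy functions
%   Note: a separate matrix is built for each rb, over the 4 state variables {b,a,y,rm}.
%   Also builds the forced-refinancing matrix and, when ARMFlag == 1, the ARM adjustment matrix.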
%--------------------------------------------------------------------------

% Preliminaries: grid sizes and linear-index lists
Mall = Nb*Na*Nz*Nr;                 % number of points in the full {b,a,z,rm} grid
pointslistAll = transpose(1:Mall);  % column of linear indices over the full grid
pointslist = transpose(1:M);        % column of linear indices within one rm slice (M presumably = Nb*Na*Nz)

%--------------------------------------------------------------------------
%% Build intervention matrix, denoted C
% Loop-invariant quantities (identical for every rb_ind)
Mbaz = Nb*Na*Nz;                              % points in one {b,a,z} slice
pointsmat = reshape(pointslist, Nb, Na, Nz);  % (b,a,z) subscripts -> index within a slice
noAdjMask = (repmat(cantAdjust(:),Nr,1)==1);  % points where adjustment is not allowed
noAdjAllowed = find(noAdjMask);

for rb_ind = 1:Nr
    rb_spot_ind = rb_ind;

    %Get adjustment regions/targets: stack the per-rm regions into one
    %indicator over all Mall points
    adj = logical(vertcat(adjStore{rb_spot_ind,1:Nr}));
    adjpoints = pointslistAll(adj);
    adjpoints = adjpoints(:);
    noadjpoints = pointslistAll(~adj);
    noadjpoints = noadjpoints(:);

    isAdj = false(Mall,1);
    isAdj(adjpoints) = true;

    % Interpolation targets are needed at every adjustment point (for C) and
    % at every point where adjustment is allowed (for FF). Compute each once.
    needTarget = isAdj | ~noAdjMask;
    nbr = zeros(Mall,4);      % linear indices of the 4 grid neighbours, column-major order [LL RL LR RR]
    wts = zeros(Mall,4);      % matching interpolation weights
    usedIDW = false(Mall,1);  % true where inverse distance weights were used

    for mStore = find(needTarget)'
        rm_ind = floor((mStore-1)/M)+1;   % rm slice of this point
        m = mod(mStore-1,Mbaz)+1;         % decouple from r state variable

        zind = ceil(m/(Nb*Na));
        ab_givenz = mod(m-1,Nb*Na)+1;

        bAdj_combine = bAdj_combineStore{rb_spot_ind,rm_ind};
        aAdj_combine = aAdj_combineStore{rb_spot_ind,rm_ind};
        rewrite_combine = rewrite_combineStore{rb_spot_ind,rm_ind};

        bprime = bAdj_combine(ab_givenz,zind);
        aprime = aAdj_combine(ab_givenz,zind);
        rewrite_flag = rewrite_combine(ab_givenz,zind);
        % With rewrite_flag in {0,1} (assumed): a rewrite moves to rm slice
        % rb_spot_ind, otherwise the point stays in its current rm slice
        post_adj_rmind = rb_spot_ind*rewrite_flag + rm_ind*(1-rewrite_flag);

        % Bracket (bprime, aprime) in the grids; the top b cell is reused at/above bmax
        bprime_left = discretize(round(bprime,13), b);
        if bprime >= bmax
            bprime_left = Nb-1;
        end
        aprime_left = discretize(round(aprime,13), a);
        bprime_right = bprime_left + 1;
        aprime_right = aprime_left + 1;

        %Map to post-adjustment region
        offset = (post_adj_rmind-1)*M;
        neighpoints = offset + [pointsmat(bprime_left, aprime_left, zind),  pointsmat(bprime_left, aprime_right, zind);
                                pointsmat(bprime_right, aprime_left, zind), pointsmat(bprime_right, aprime_right, zind)];
        % true = neighbour is not an adjustment point, so mass may be sent there
        pointscheck = ~isAdj(neighpoints);

        if all(pointscheck(:)) && BilinW == 1 % use weights proportional to area opposite the point
            totarea = (b(bprime_right) - b(bprime_left)) * (a(aprime_right) - a(aprime_left));
            weights = totarea^(-1) * [b(bprime_right) - bprime; bprime - b(bprime_left)] * [a(aprime_right) - aprime aprime - a(aprime_left)];
        else
            usedIDW(mStore) = true; % use inverse distance weights
            neigh_bvals = [b(bprime_left)  b(bprime_left);
                           b(bprime_right) b(bprime_right)];
            neigh_avals = [a(aprime_left) a(aprime_right);
                           a(aprime_left) a(aprime_right)];
            dist_points = sqrt((neigh_bvals - bprime).^2 + (neigh_avals - aprime).^2); % Euclidean distance
            inverse_dist = 1 ./ dist_points;
            % Remove adjustment neighbours by assignment, not multiplication:
            % an exact hit on an adjustment point gives Inf, and Inf*0 = NaN
            inverse_dist(~pointscheck) = 0;
            if sum(isinf(inverse_dist(:))) == 1
                weights = double(isinf(inverse_dist)); % exact hit on a kept grid point
            else
                weights = inverse_dist / sum(inverse_dist(:));
            end
        end
        nbr(mStore,:) = neighpoints(:)';
        wts(mStore,:) = weights(:)';
    end

    count_idwpoints = nnz(usedIDW(adjpoints)); % count times inverse distance weights used to allocate mass
    display(['rb = ', num2str(rb_ind), '. Number of points where IDWs were used = ', num2str(count_idwpoints)])

    % Rows become NaN when all four neighbours are adjustment points (0/0)
    nBad = nnz(any(~isfinite(wts), 2));
    if nBad > 0
        warning('interventionMatrix:nonfiniteWeights', ...
            'rb = %d: %d points have non-finite interpolation weights (all neighbours are adjustment points?)', ...
            rb_ind, nBad);
    end

    %----------------------------------------------------------------------
    % Intervention matrix C: identity rows at non-adjustment points, and
    % interpolation weights on the post-adjustment neighbours at adjustment
    % points. Built from triplets in one sparse() call (duplicates are summed)
    % instead of one sparse addition per point, which is quadratic in Mall.
    adjRows = repmat(adjpoints, 1, 4);
    adjNbr = nbr(adjpoints,:);
    adjWts = wts(adjpoints,:);
    C = sparse([noadjpoints; adjRows(:)], [noadjpoints; adjNbr(:)], ...
               [ones(numel(noadjpoints),1); adjWts(:)], Mall, Mall);
    interventionStore{rb_ind} = C;

    %----------------------------------------------------------------------
    % Forced refis at all points: la_forced*(weights - identity) on every row
    % where adjustment is allowed; rows in noAdjAllowed are left empty.
    ffRows = find(~noAdjMask);
    ffRowsRep = repmat(ffRows, 1, 4);
    ffNbr = nbr(ffRows,:);
    ffWts = wts(ffRows,:);
    FF = sparse(ffRowsRep(:), ffNbr(:), la_forced*ffWts(:), Mall, Mall) ...
         - sparse(ffRows, ffRows, la_forced*ones(numel(ffRows),1), Mall, Mall);
    forcedRefiStore{rb_ind} = FF;

    %----------------------------------------------------------------------
    % ARM adjustments
    if ARMFlag == 1
        %Adjust to same place in b-a-z state space, but with rm = rb_ind.
        %Every point maps to its (b,a,z) position in rm slice rb_ind
        %(Nr copies of that slice cover all Mall rows).
        sliceCols = ((rb_ind-1)*Mbaz+1 : rb_ind*Mbaz)';
        ARMARM = sparse((1:Mall)', repmat(sliceCols, Nr, 1), ones(Mall,1), Mall, Mall);
        ARMStore{rb_ind} = ARMARM;
    end
end